{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='green'>1. Description</font>\n",
    "\n",
    "Multi-class sentiment classification on the Amazon book review dataset.\n",
    "The dataset can be downloaded from https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Books_v1_02.tsv.gz\n",
    "\n",
    "Consumer reviews give businesses feedback on performance, product quality, and customer service. An online review typically consists of free-form text and a star rating out of 5. Predicting a user's star rating for a product from the text of that user's review has recently become a popular, albeit hard, problem in machine learning.\n",
    "Using this dataset, we train a classifier to predict the product rating from the review text.\n",
    "\n",
    "Predicting ratings from text is a particularly difficult task. The primary reason is that two people can write similar reviews yet give different ratings. As the rating scale widens (for example, from a 5-point to a 10-point scale), the task becomes increasingly difficult."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='green'>2. Data Preprocessing</font>\n",
    "\n",
    "For Amazon review classification, we first clean and prepare the data, then generate word2vec embeddings for the review text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import time\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import word_tokenize\n",
    "from sklearn import metrics\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_review(review):\n",
    "    pattern = re.compile(r'[^a-zA-Z0-9 ]')\n",
    "    review = pattern.sub(' ', review)\n",
    "    return review\n",
    "\n",
    "def document_vector_frovedis(doc, frov_w2v_model, frov_vocab):\n",
    "    \"\"\"Create document vectors by averaging word vectors. Remove out-of-vocabulary words.\"\"\"\n",
    "    no_embedding = np.zeros(frov_w2v_model.shape[1])\n",
    "    vocab_doc = [word for word in doc if word in frov_vocab]\n",
    "    if len(vocab_doc) != 0:\n",
    "        return list(np.mean(frov_w2v_model.loc[vocab_doc], axis=0))\n",
    "    else:\n",
    "        return list(no_embedding)"
   ]
  },
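  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The document-vector helper above represents a review by the average of the vectors of its in-vocabulary words, and falls back to a zero vector when no word is known. The cell below is a minimal pure-Python sketch of that idea. `toy_vectors` and `average_vector` are hypothetical names introduced only for illustration, and the 2-dimensional vectors are made up rather than taken from a trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: toy_vectors and average_vector are hypothetical,\n",
    "# not part of Frovedis or of the pipeline in this notebook.\n",
    "toy_vectors = {\n",
    "    \"great\": [1.0, 0.0],\n",
    "    \"book\": [0.0, 1.0],\n",
    "}\n",
    "\n",
    "def average_vector(tokens, vectors, dim):\n",
    "    \"\"\"Average the vectors of known tokens; return zeros if none are known.\"\"\"\n",
    "    known = [vectors[t] for t in tokens if t in vectors]\n",
    "    if not known:\n",
    "        return [0.0] * dim\n",
    "    # zip(*known) groups the values of each dimension together\n",
    "    return [sum(column) / len(known) for column in zip(*known)]\n",
    "\n",
    "# The out-of-vocabulary token \"zzz\" is dropped before averaging\n",
    "print(average_vector([\"great\", \"book\", \"zzz\"], toy_vectors, 2))\n",
    "# No known tokens at all: the zero vector is returned\n",
    "print(average_vector([\"zzz\"], toy_vectors, 2))"
   ]
  },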
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
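    "def create_w2v_embed(df):\n",
    "    \"\"\"Train Frovedis Word2Vec on the tokenized reviews and return one mean-pooled vector per review.\"\"\"\n",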
    "    from frovedis.exrpc.server import FrovedisServer\n",
    "    from frovedis.mllib.feature.w2v import Word2Vec as Frovedis_Word2Vec\n",
    "    os.environ[\"VE_OMP_NUM_THREADS\"] = '8'\n",
    "    FrovedisServer.initialize(\"mpirun -np 1 \" + os.environ[\"FROVEDIS_SERVER\"])\n",
    "    frovedis_w2v = Frovedis_Word2Vec(sentences = list(df[\"review_body\"]), hiddenSize=512, minCount=2, n_iter=100)\n",
    "    x_emb = frovedis_w2v.transform(list(df[\"review_body\"]), func = np.mean)\n",
    "    os.environ[\"VE_OMP_NUM_THREADS\"] = '1'\n",
    "    FrovedisServer.shut_down()\n",
    "    return pd.DataFrame(x_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_data(fname):\n",
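    "    # error_bad_lines=False skips malformed rows; on pandas >= 1.3 use on_bad_lines='skip' instead\n",
    "    df = pd.read_csv(fname, sep='\\t', error_bad_lines=False)\n",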
    "    print(df.shape)\n",
    "    df = df[[\"review_body\", \"star_rating\"]]\n",
    "    df = df.dropna().drop_duplicates().sample(frac=1, random_state=42)\n",
    "    df['review_body'] = df['review_body'].str.lower().apply(clean_review)\n",
    "    print(\"Dataset contains {} reviews\".format(df.shape[0]))\n",
    "    rating_categories = df[\"star_rating\"].value_counts()\n",
    "    stop = set(stopwords.words('english'))  # a set gives constant-time membership tests\n",
    "    df['review_body'] = df['review_body'].apply(lambda x: [item for item in word_tokenize(x) if item not in stop])\n",
    "    x = create_w2v_embed(df)\n",
    "    x_train, x_test, y_train, y_test = train_test_split(x, df[\"star_rating\"], random_state=42)\n",
    "    return x_train, x_test, y_train, y_test, rating_categories"
   ]
  },
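  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the text-cleaning steps in `preprocess_data` concrete, the cell below applies them to a single made-up review. It is only a sketch under simplifying assumptions: `toy_stopwords` is a tiny hypothetical stand-in for NLTK's English stop-word list, `str.split()` stands in for `word_tokenize`, and `toy_preprocess` is an illustrative name, not part of the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Illustrative sketch only: toy_stopwords is a hypothetical stand-in for\n",
    "# NLTK's English stop-word list, and str.split() stands in for word_tokenize.\n",
    "toy_stopwords = {\"this\", \"is\", \"a\", \"the\", \"it\"}\n",
    "\n",
    "def toy_preprocess(review):\n",
    "    \"\"\"Lowercase, replace non-alphanumeric characters with spaces, drop stop words.\"\"\"\n",
    "    cleaned = re.sub(r'[^a-zA-Z0-9 ]', ' ', review.lower())\n",
    "    return [token for token in cleaned.split() if token not in toy_stopwords]\n",
    "\n",
    "print(toy_preprocess(\"This is a GREAT book -- loved it!\"))"
   ]
  },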
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "b'Skipping line 1680001: expected 15 fields, saw 22\\n'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3105370, 15)\n",
      "Dataset contains 3071369 reviews\n",
      "shape of train data: (2303526, 512)\n",
      "shape of test data: (767843, 512)\n"
     ]
    }
   ],
   "source": [
    "#---- Data Preparation ----\n",
    "# Uncomment the lines below to download the dataset and unpack it to the path used in DATA_FILE.\n",
    "#!wget -N https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Books_v1_02.tsv.gz\n",
    "#!gunzip amazon_reviews_us_Books_v1_02.tsv.gz\n",
    "#!mkdir -p datasets/amazon_reviews_us_Books_v1_02.tsv\n",
    "#!mv amazon_reviews_us_Books_v1_02.tsv datasets/amazon_reviews_us_Books_v1_02.tsv/\n",
    "\n",
    "DATA_FILE = \"datasets/amazon_reviews_us_Books_v1_02.tsv/amazon_reviews_us_Books_v1_02.tsv\"\n",
    "x_train, x_test, y_train, y_test, rating_categories = preprocess_data(DATA_FILE)\n",
    "print(\"shape of train data: {}\".format(x_train.shape))\n",
    "print(\"shape of test data: {}\".format(x_test.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEcCAYAAADdtCNzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdVUlEQVR4nO3deZxddZ3m8c9DEkAWZUlETAhBmsWoECEiCkrcIIhAt4IkgoKNZrSh0dFxGqZtUBgdbF+tjg0MRAkRHVkVjRAERlZZE/ZFAiFAkygmJOzQQuCZP84pc6mcqrpJ6tS5qXrer9d91T2/3znnfusmOU9+Z5VtIiIiulun6QIiIqIzJSAiIqJSAiIiIiolICIiolICIiIiKiUgIiKiUgIiBhVJV0v63EAvWy7/PknzVnf5ivVdKunw8v0Rkn7fj+s+VNLl/bW+GJwSENGRJD0i6cNN19FF0jckvSzp2fL1gKRTJG3ZNY/t62zv0Oa6ftbXfLb3tf2Tfqh9nCRLGt6y7v9re+81XXcMbgmIiPadZ3tjYDPg74A3Abe2hkR/UCH/NqNx+UsYaxVJm0q6WNISSU+W78d0m21bSbdIekbSryVt1rL87pJukPSUpDslTVrVGmy/bPte4BBgCfDVct2TJC1s+ax/krSoHHHMk/QhSZOB/wEcIuk5SXeW814t6VuSrgdeAN5SsctL5ajlaUn3S/pQS8drRlzdRinXlj+fKj/zPd13WUl6r6Q55brnSHpvS9/Vkk6SdH35u1wuaeSqfm+x9klAxNpmHeAsYGtgLPAicEq3eT4D/D2wJbAc+CGApNHAJcD/pBgF/DfgF5JGrU4htl8Bfg28r3ufpB2Ao4F3laOOfYBHbP8W+DbFaGQj2zu3LPZpYBqwMfBoxUe+G3gIGAmcAPyyNfx68f7y5yblZ97YrdbNKL6XHwKbA98DLpG0ectsnwI+C7wRWJfiu4tBbtAFhKQZkhZLuqfN+T8p6T5J90r6ed31xZqxvdT2L2y/YPtZ4FvAXt1m+6nte2w/D/wL8ElJw4DDgNm2Z9t+1fYVwFzgo2tQ0h8pwqa7V4D1gPGSRth+xPZDfaxrpu17bS+3/XJF/2LgB+UI5jxgHrDfGtTeZT/gQds/LT/7HOB+YP+Wec6y/YDtF4HzgQn98LnR4QZdQAAzgcntzChpO+A4YA/bbwO+XF9Z0R8kbSDpDEmPSnqGYvfJJmUAdHms5f2jwAiK/3VvDRxc7l56StJTwJ4UI43VNRpY1r3R9nyKv0/fABZLOlfSm/tY12N99C/ya++u+SjQ1zrb8WZWHrE8SvG7dXm85f0LwEb98LnR4QZdQNi+lm7/YCVtK+m3km6VdJ2kHcuuzwOn2n6yXHbxAJcbq+6rwA7Au22/nhW7T9Qyz1Yt78cCLwNPUGyAf2p7k5bXhrZPXp1CygPJ+wPXVfXb/rntPSmCycB3urp6WGVft1YeLan19xxLMYIBeB7YoKXvTauw3j+WNbYaCyzqY7kY5AZdQPRgOvCPtnel2Hd6Wtm+PbB9efDtpvIAYnSOEZLWb3kNp9g//yLFAdfNKPbFd3eYpPGSNgBOBC4sjxf8DNhf0j6ShpXrnFRxkLtXkoZLeitwDsWG+HsV8+wg6YOS1gP+s6z51bL7z8C41ThT6Y3AMZJGSDoYeCswu+y7A5hS9k0EDmpZbkn52W/pYb2zKf4dfKr83Q4BxgMXr2J9McgM+oCQtBHwXuACSXcAZ7Bil8JwYDtgEjAV+JGkTQa+yujBbIoNa9frG8APgNdRjAhuAn5bsdxPKXY1Pg6sDxwDYPsx4ECKs4iWUIwovkb7/w4OkfQc8DQwC1gK7Gr7jxXzrgecXNb5OMXG/biy74Ly51JJt7X52QA3U/x9fYLi2MtBtpeWff8CbAs8CXwT+OvxNNsvlPNfX+5a2711peU6PkYxOlsK/HfgY7afWIXaYhDSYHxgkKRxwMW23y7p9cA8
2yvtZ5Z0OnCz7bPK6d8Bx9qeM6AFR0R0oEE/grD9DPBwOSTvugip69TCX1GMHijP694eWNBAmRERHWfQBYSkc4AbgR0kLZR0JHAocGR5UdK9FLsZAC6jGObfB1wFfK1lyB4RMaQNyl1MERGx5gbdCCIiIvrH8L5nWXuMHDnS48aNa7qMiIi1xq233vqE7crbzQyqgBg3bhxz585tuoyIiLWGpKr7fgHZxRQRET1IQERERKUEREREVEpAREREpQRERERUSkBERESlBERERFRKQERERKUEREREVBpUV1KvqXHHXtJ0CQA8cnJ/PIc+ImLNZAQRERGVEhAREVEpAREREZUSEBERUSkBERERlRIQERFRqbbTXCXNAD4GLLb99or+rwGHttTxVmCU7WWSHgGeBV4BltueWFedERFRrc4RxExgck+dtr9re4LtCcBxwDW2l7XM8oGyP+EQEdGA2gLC9rXAsj5nLEwFzqmrloiIWHWNH4OQtAHFSOMXLc0GLpd0q6RpfSw/TdJcSXOXLFlSZ6kREUNK4wEB7A9c32330p62dwH2BY6S9P6eFrY93fZE2xNHjRpVd60REUNGJwTEFLrtXrK9qPy5GLgI2K2BuiIihrRGA0LSG4C9gF+3tG0oaeOu98DewD3NVBgRMXTVeZrrOcAkYKSkhcAJwAgA26eXs/0dcLnt51sW3QK4SFJXfT+3/du66oyIiGq1BYTtqW3MM5PidNjWtgXAzvVUFRER7eqEYxAREdGBEhAREVEpAREREZUSEBERUSkBERERlRIQERFRKQERERGVEhAREVEpAREREZUSEBERUSkBERERlRIQERFRKQERERGVEhAREVEpAREREZUSEBERUSkBERERlRIQERFRKQERERGVagsISTMkLZZ0Tw/9kyQ9LemO8nV8S99kSfMkzZd0bF01RkREz+ocQcwEJvcxz3W2J5SvEwEkDQNOBfYFxgNTJY2vsc6IiKhQW0DYvhZYthqL7gbMt73A9kvAucCB/VpcRET0qeljEO+RdKekSyW9rWwbDTzWMs/Csq2SpGmS5kqau2TJkjprjYgYUpoMiNuArW3vDPw78KvVWYnt6bYn2p44atSo/qwvImJIaywgbD9j+7ny/WxghKSRwCJgq5ZZx5RtERExgBoLCElvkqTy/W5lLUuBOcB2kraRtC4wBZjVVJ0REUPV8LpWLOkcYBIwUtJC4ARgBIDt04GDgC9KWg68CEyxbWC5pKOBy4BhwAzb99ZVZ0REVKstIGxP7aP/FOCUHvpmA7PrqCsiItrT9FlMERHRoRIQERFRKQERERGVEhAREVEpAREREZUSEBERUSkBERERlRIQERFRKQERERGVEhAREVEpAREREZUSEBERUSkBERERlRIQERFRKQERERGVEhAREVEpAREREZX6DAhJG0pap3y/vaQDJI2ov7SIiGhSOyOIa4H1JY0GLgc+Dcyss6iIiGheOwEh2y8AHwdOs30w8LY+F5JmSFos6Z4e+g+VdJekuyXdIGnnlr5HyvY7JM1t95eJiIj+01ZASHoPcChwSdk2rI3lZgKTe+l/GNjL9juAk4Dp3fo/YHuC7YltfFZERPSz4W3M82XgOOAi2/dKegtwVV8L2b5W0rhe+m9ombwJGNNGLRERMUD6DAjb1wDXSNqgnF4AHNPPdRwJXNr6scDlkgycYbv76OKvJE0DpgGMHTu2n8uKiBi62jmL6T2S7gPuL6d3lnRafxUg6QMUAfFPLc172t4F2Bc4StL7e1re9nTbE21PHDVqVH+VFREx5LVzDOIHwD7AUgDbdwI9brBXhaSdgB8DB9pe2tVue1H5czFwEbBbf3xeRES0r60L5Ww/1q3plTX9YEljgV8Cn7b9QEv7hpI27noP7A1UngkVERH1aecg9WOS3gu4vEDuS8Af+lpI0jnAJGCkpIXACcAIANunA8cDmwOnSQJYXp6xtAVwUdk2HPi57d+u4u8VERFrqJ2A+ALw
v4HRwCKKi+WO6msh21P76P8c8LmK9gXAzisvERERA6mdgHjO9qG1VxIRER2lnYC4R9KfgevK1+9tP11vWRER0bQ+D1Lb/htgKnA3sB9wp6Q7aq4rIiIa1ucIQtIYYA/gfRTHBu4Ffl9zXRER0bB2djH9BzAH+LbtL9RcT0REdIh2roN4J3A28ClJN0o6W9KRNdcVERENa+deTHdKegh4iGI302HAXsCZNdcWERENaucYxFxgPeAGirOY3m/70boLi4iIZrVzDGJf20tqryQiIjpKO8cg1pF0pqRLASSNzzGIiIjBr52AmAlcBry5nH6A4iFCERExiLUTECNtnw+8CmB7Of1wN9eIiOhs7QTE85I2p3jKG5J2B3KrjYiIQa6dg9RfAWYB20q6HhgFHFRrVRER0bh2roO4TdJewA6AgHm2X669soiIaFSPASHpg7avlPTxbl3bS8L2L2uuLSIiGtTbCGIv4Epg/4o+UzwuNCIiBqkeA8L2CeXbz9nOWUsREUNMO2cxPSxpuqQPqXxQdEREDH7tBMSOwP+jeA71w5JOkbRnOyuXNEPSYkn39NAvST+UNF/SXZJ2aek7XNKD5evwdj4vIiL6TztPlHvB9vm2P05x6+/XA9e0uf6ZwORe+vcFtitf04D/AyBpM+AE4N3AbsAJkjZt8zMjIqIftDOCQNJekk4DbgXWBz7ZznK2rwWW9TLLgcDZLtwEbCJpS2Af4Arby2w/CVxB70ETERH9rJ3bfT8C3A6cD3zN9vP9+PmjgcdapheWbT21R0TEAGnnSuqdbD9TeyWrSdI0it1TjB07tuFqIiIGj3Z2Mb1J0u+6DjRL2knS1/vp8xcBW7VMjynbempfie3ptifanjhq1Kh+KisiItoJiB8BxwEvA9i+C5jST58/C/hMeTbT7sDTtv9EcXvxvSVtWh6c3rtsi4iIAdLOLqYNbN/S7RKI5e2sXNI5wCRgpKSFFGcmjQCwfTowG/goMB94Afhs2bdM0knAnHJVJ9ru7WB3RET0s3YC4glJ27Lidt8HAX9qZ+W2p/bRb4rrK6r6ZgAz2vmciIjof+0ExFHAdGBHSYuAh4FDa60qIiIa12tASBoG/IPtD0vaEFjH9rMDU1pERDSp14Cw/UrXbTX6+fqHiIjocO3sYrpd0izgAuCvIZHnQUREDG7tBMT6wFLggy1teR5ERMQg184jRz87EIVERERnaetmfRERMfQkICIiolKPASHpS+XPPQaunIiI6BS9jSC6jj38+0AUEhERnaW3g9R/kPQg8GZJd7W0i+IuGTvVW1pERDSpx4CwPVXSmyjuonrAwJUUERGdoK8rqR8Hdpa0LrB92TzP9su1VxYREY1q55GjewFnA49Q7F7aStLh5fOmIyJikGrnSurvAXvbngcgaXvgHGDXOguLiIhmtXMdxIiucACw/QDlQ38iImLwamcEMVfSj4GfldOHAnPrKykiIjpBOwHxRYqHBh1TTl8HnFZbRRER0RHauVnfXyiOQ3yv/nIiIqJT1HovJkmTJc2TNF/SsRX935d0R/l6QNJTLX2vtPTNqrPOiIhYWTu7mFZL+bjSU4GPAAuBOZJm2b6vax7b/7Vl/n8E3tmyihdtT6irvoiI6F2dI4jdgPm2F9h+CTgXOLCX+adSnD4bEREdYLUCQtK0NmYbDTzWMr2wbKta39bANsCVLc3rS5or6SZJf7s6dUZExOpb3V1M6tcqYApwoe1XWtq2tr1I0luAKyXdbfuhlQopwmoawNixY/u5rIiIoWu1RhC2z2hjtkXAVi3TY8q2KlPotnvJ9qLy5wLgal57fKJ1vum2J9qeOGrUqDbKioiIdvQZEJLGSLpI0hJJiyX9QtKYNtY9B9hO0jblzf6mACudjSRpR2BT4MaWtk0lrVe+HwnsAdzXfdmIiKhPOyOIsyg27FsCbwZ+U7b1yvZy4GiK24X/ATjf9r2STpTUevvwKcC5tt3S9laKK7jvBK4CTm49+yki
IurXzjGIUbZbA2GmpC+3s3Lbs4HZ3dqO7zb9jYrlbgDe0c5nREREPdoZQSyVdJikYeXrMGBp3YVFRESz2gmIvwc+CTwO/Ak4iBXPq46IiEGqnXsxPUoeORoRMeT0GBCSju+pD7Dtk2qoJyIiOkRvI4jnK9o2BI4ENgcSEBERg1iPAWH737reS9oY+BLFsYdzgX/rabmIiBgcej0GIWkz4CsUT5H7CbCL7ScHorCIiGhWb8cgvgt8HJgOvMP2cwNWVURENK6301y/SnHl9NeBP0p6pnw9K+mZgSkvIiKa0tsxiFqfNhcREZ0tIRAREZUSEBERUSkBERERlRIQERFRaXUfORqD3LhjL2m6BAAeOXm/pkuIGLIygoiIiEoJiIiIqJSAiIiISgmIiIioVGtASJosaZ6k+ZKOreg/QtISSXeUr8+19B0u6cHydXiddUZExMpqO4tJ0jDgVOAjwEJgjqRZtu/rNut5to/utuxmwAnARMDAreWyuZNsRMQAqXMEsRsw3/YC2y9RPEfiwDaX3Qe4wvayMhSuACbXVGdERFSoMyBGA4+1TC8s27r7hKS7JF0oaatVXDYiImrS9EHq3wDjbO9EMUr4yaquQNI0SXMlzV2yZEm/FxgRMVTVGRCLgK1apseUbX9le6ntv5STPwZ2bXfZlnVMtz3R9sRRo0b1S+EREVFvQMwBtpO0jaR1gSnArNYZJG3ZMnkA8Ify/WXA3pI2lbQpsHfZFhERA6S2s5hsL5d0NMWGfRgww/a9kk4E5tqeBRwj6QBgObAMOKJcdpmkkyhCBuBE28vqqjUiIlZW6836bM8GZndrO77l/XHAcT0sOwOYUWd9ERHRs6YPUkdERIdKQERERKUEREREVEpAREREpQRERERUSkBERESlBERERFRKQERERKUEREREVEpAREREpQRERERUSkBERESlBERERFRKQERERKUEREREVEpAREREpQRERERUSkBERESlBERERFRKQERERKVaA0LSZEnzJM2XdGxF/1ck3SfpLkm/k7R1S98rku4oX7PqrDMiIlY2vK4VSxoGnAp8BFgIzJE0y/Z9LbPdDky0/YKkLwL/ChxS9r1oe0Jd9UVERO/qHEHsBsy3vcD2S8C5wIGtM9i+yvYL5eRNwJga64mIiFVQ2wgCGA081jK9EHh3L/MfCVzaMr2+pLnAcuBk27+qWkjSNGAawNixY9ek3ohK4469pOkSAHjk5P2aLiGGmDoDom2SDgMmAnu1NG9te5GktwBXSrrb9kPdl7U9HZgOMHHiRA9IwRFDVMJyaKlzF9MiYKuW6TFl22tI+jDwz8ABtv/S1W57UflzAXA18M4aa42IiG7qDIg5wHaStpG0LjAFeM3ZSJLeCZxBEQ6LW9o3lbRe+X4ksAfQenA7IiJqVtsuJtvLJR0NXAYMA2bYvlfSicBc27OA7wIbARdIAvgP2wcAbwXOkPQqRYid3O3sp4iIqFmtxyBszwZmd2s7vuX9h3tY7gbgHXXWFhERvcuV1BERUSkBERERlRIQERFRqSOug4iIWNsMhWtCMoKIiIhKCYiIiKiUgIiIiEoJiIiIqJSAiIiISgmIiIiolICIiIhKCYiIiKiUgIiIiEoJiIiIqJSAiIiISgmIiIiolICIiIhKCYiIiKiUgIiIiEq1BoSkyZLmSZov6diK/vUknVf23yxpXEvfcWX7PEn71FlnRESsrLaAkDQMOBXYFxgPTJU0vttsRwJP2v4b4PvAd8plxwNTgLcBk4HTyvVFRMQAqXMEsRsw3/YC2y8B5wIHdpvnQOAn5fsLgQ9JUtl+ru2/2H4YmF+uLyIiBkidjxwdDTzWMr0QeHdP89heLulpYPOy/aZuy46u+hBJ04Bp5eRzkuateelrZCTwxJqsQN/pp0qal+9ihXwXK+S7WKETvoute+pY659JbXs6ML3pOrpImmt7YtN1dIJ8Fyvku1gh38UKnf5d1LmLaRGwVcv0mLKtch5Jw4E3
AEvbXDYiImpUZ0DMAbaTtI2kdSkOOs/qNs8s4PDy/UHAlbZdtk8pz3LaBtgOuKXGWiMiopvadjGVxxSOBi4DhgEzbN8r6URgru1ZwJnATyXNB5ZRhAjlfOcD9wHLgaNsv1JXrf2sY3Z3dYB8Fyvku1gh38UKHf1dqPgPe0RExGvlSuqIiKiUgIiIiEoJiIiIqJSAiH4laTNJmzVdR6fI9xFrswREP5C0haRdytcWTdcz0CSNlXSupCXAzcAtkhaXbeMaLm/A5fuI3qxN24ucxbQGJE0ATqe4wK/rQr4xwFPAP9i+rZnKBpakG4EfABd2nY5c3lzxYODLtndvsLwBl+9jZeWGsOt2OYts/7nJepqwNm4vEhBrQNIdwH+xfXO39t2BM2zv3EhhA0zSg7a3W9W+wSrfxwpr40axLmvj9mKtvxdTwzbs/ocNYPsmSRs2UVBDbpV0GsWdebtu0LgVxVXytzdWVXPyfawwk543imcBHbdRrNFat73ICGINSPohsC1wNq/dEHwGeNj20U3VNpDKW6kcSXGb9q7dCAuB3wBn2v5LU7U1Id/HCn2MpuaXz4IZEtbG7UUCYg1J2pfXbggWAbNsz26uqojOsDZuFOu0tm0vEhBRK0kfs31x03V0iqH4faxtG8VYIccgaiJpWvmsiqHuXcCQ2iD2Ych9H7YvBS5tuo5O1qnbi1wHUR81XUCTJJ0NYPuEpmtpgqTdJL2rfD9e0lckfXSofh9VyqdBRqEjtxcZQfQjSXtSPDv7HttnNF3PQJHU/TkfAj4gaRMA2wcMeFENknQCsC8wXNIVFI/avQo4VtI7bX+r0QI7R0duFOskaUeKXW03236upevRhkrqVY5BrAFJt9jerXz/eeAo4CJgb+A3tk9usr6BIuk2imd3/BgwxT/8c1jxfI9rmqtu4Em6G5gArAc8Doyx/Yyk11FsGHZqsr5OIemzts9quo6BIukYim3EHyj+fnzJ9q/Lvtts79JgeZWyi2nNjGh5Pw34iO1vUgTEoc2U1IiJwK3APwNP274aeNH2NUMtHErLbb9i+wXgIdvPANh+EXi12dI6yjebLmCAfR7Y1fbfApOAf5H0pbKvI0dT2cW0ZtaRtClF0Mr2EgDbz0ta3mxpA8f2q8D3JV1Q/vwzQ/vv1kuSNigDYteuRklvYIgFhKS7euoCOvo+RDVYp2u3ku1HJE0CLpS0NQmIQekNFP9zFmBJW9r+k6SN6NA/8DrZXggcLGk/4Jmm62nQ+7suhivDs8sIVjyDfajYAtgHeLJbu4AbBr6cRv1Z0gTbdwDYfk7Sx4AZwDsarawHOQZRA0kbAFvYfrjpWiKaJOlM4Czbv6/o+7ntTzVQViMkjaHY/fh4Rd8etq9voKxeJSAiIqJSDlJHRESlBERERFRKQMSQJ+kVSXdIukfSb7ou8Otl/gmSPtoyfYCkY/uplpmSDqponyTp4vL9EZK+0R+fF9GbBEREcc3GBNtvB5ZRXMzUmwnAXwPC9qyhclFkDC0JiIjXupHyrqPl/ZRulHS7pBsk7VA+6+FE4JBy1HFI+T/6U8plZkr6YTn/gq7RgKR1JJ0m6X5JV0iaXTVSaCVpcjn/bcDHW7peBJ4r5zm4HPncKena/v86YijLdRARpfK50R8Cziyb7gfeZ3u5pA8D37b9CUnHAxO7nmUg6Yhuq9oS2BPYEZgFXEixgR8HjAfeSHG7hRm91LI+8CPgg8B84LyuPtvntcx6PLCP7UV97RqLWFUZQUTA68rnBT9OcWHXFWX7G4ALJN0DfB94W5vr+5XtV23fx4qrhfcELijbH6e4eV9vdqR4oM6DLs5F/1kP810PzCzvBTaszfoi2pKAiCiPQQBdtzzoOgZxEnBVeWxif2D9NtfX+kjRWq+ot/0F4OsUT2m7VdLmdX5eDC0JiIhSee+kY4CvShpOMYJYVHYf0TLrs8DGq7j664FPlMcitqC4WVtv7gfGSdq2nJ5aNZOkbW3fbPt4YAlF
UET0iwRERAvbtwN3UWyQ/xX4X5Ju57XH664CxncdpG5z1b8AFlLcFv1nwG3A073U8Z8Udwi+pDxIvbiHWb8r6e5yN9gNwJ1t1hPRp9xqI2KASNqovEHb5sAtwB5V9+WJ6BQ5iyli4Fxcnmm0LnBSwiE6XUYQERFRKccgIiKiUgIiIiIqJSAiIqJSAiIiIiolICIiotL/B9XETXPMSZ8dAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Label distribution summary\n",
    "ax = rating_categories.plot(kind='bar', title='Label Distribution')\n",
    "_ = ax.set(xlabel=\"Star rating\", ylabel=\"No. of reviews\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='green'> 3. Algorithm Evaluation</font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_time = []\n",
    "test_time = []\n",
    "accuracy = []\n",
    "precision = []\n",
    "recall = []\n",
    "f1 = []\n",
    "estimator_name = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
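    "def evaluate(estimator, estimator_nm,\n",
    "             x_train, y_train,\n",
    "             x_test, y_test):\n",
    "    \"\"\"Fit and test an estimator, record timings and macro-averaged metrics, and return the per-class report.\"\"\"\n",
    "    estimator_name.append(estimator_nm)\n",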
    "\n",
    "    start_time = time.time()\n",
    "    estimator.fit(x_train, y_train)\n",
    "    train_time.append(round(time.time() - start_time, 4))\n",
    "\n",
    "    start_time = time.time()\n",
    "    pred_y = estimator.predict(x_test)\n",
    "    test_time.append(round(time.time() - start_time, 4))\n",
    "\n",
    "    accuracy.append(metrics.accuracy_score(y_test, pred_y))\n",
    "    precision.append(metrics.precision_score(y_test, pred_y, average='macro'))\n",
    "    recall.append(metrics.recall_score(y_test, pred_y, average='macro'))\n",
    "    f1.append(metrics.f1_score(y_test, pred_y, average='macro'))\n",
    "\n",
    "    target_names = ['rating 1.0', 'rating 2.0', 'rating 3.0', 'rating 4.0', 'rating 5.0']\n",
    "    return metrics.classification_report(y_test, pred_y, target_names=target_names)"
   ]
  },
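  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`evaluate` records macro-averaged precision, recall and F1, which weight all five rating classes equally regardless of how many reviews each class has. The cell below is a small pure-Python sketch of why that matters on imbalanced data: a model that always predicts the majority class can reach a reasonable accuracy while its macro recall is only 1 / (number of classes). The labels are made up, and `macro_recall` is a hypothetical helper written for illustration, not scikit-learn's implementation. The scikit-learn decision tree report in section 3.2 shows the same pattern: only rating 5 is predicted, giving about 0.60 accuracy but 0.20 macro recall."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: the labels are made up and macro_recall is a\n",
    "# hypothetical helper, not scikit-learn's implementation.\n",
    "toy_true = [5, 5, 5, 5, 5, 5, 4, 3, 2, 1]   # imbalanced toy labels\n",
    "toy_pred = [5] * len(toy_true)               # always predict the majority class\n",
    "\n",
    "def macro_recall(true_labels, pred_labels):\n",
    "    \"\"\"Unweighted mean of per-class recall over the classes present in true_labels.\"\"\"\n",
    "    recalls = []\n",
    "    for label in sorted(set(true_labels)):\n",
    "        # predictions made for the samples whose true class is `label`\n",
    "        preds = [p for t, p in zip(true_labels, pred_labels) if t == label]\n",
    "        recalls.append(sum(p == label for p in preds) / len(preds))\n",
    "    return sum(recalls) / len(recalls)\n",
    "\n",
    "toy_accuracy = sum(t == p for t, p in zip(toy_true, toy_pred)) / len(toy_true)\n",
    "print(\"accuracy:\", toy_accuracy)\n",
    "print(\"macro recall:\", macro_recall(toy_true, toy_pred))"
   ]
  },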
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.1 Bernoulli Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frovedis Bernoulli Naive Bayes metrics: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "  rating 1.0       0.34      0.49      0.40     59017\n",
      "  rating 2.0       0.17      0.31      0.22     41354\n",
      "  rating 3.0       0.20      0.22      0.21     61591\n",
      "  rating 4.0       0.28      0.31      0.30    144753\n",
      "  rating 5.0       0.78      0.65      0.71    461128\n",
      "\n",
      "    accuracy                           0.52    767843\n",
      "   macro avg       0.35      0.39      0.37    767843\n",
      "weighted avg       0.57      0.52      0.54    767843\n",
      "\n",
      "Sklearn Bernoulli Naive Bayes metrics: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "  rating 1.0       0.34      0.49      0.40     59017\n",
      "  rating 2.0       0.17      0.31      0.22     41354\n",
      "  rating 3.0       0.20      0.22      0.21     61591\n",
      "  rating 4.0       0.28      0.31      0.30    144753\n",
      "  rating 5.0       0.78      0.65      0.71    461128\n",
      "\n",
      "    accuracy                           0.52    767843\n",
      "   macro avg       0.35      0.39      0.37    767843\n",
      "weighted avg       0.57      0.52      0.54    767843\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Demo: Bernoulli Naive Bayes\n",
    "import frovedis\n",
    "TARGET = \"bernoulli_naive_bayes\"\n",
    "from frovedis.exrpc.server import FrovedisServer\n",
    "FrovedisServer.initialize(\"mpirun -np 8 \" + os.environ[\"FROVEDIS_SERVER\"])\n",
    "from frovedis.mllib.naive_bayes import BernoulliNB as frovNB\n",
    "\n",
    "f_est = frovNB(alpha=1.0)\n",
    "E_NM = TARGET + \"_frovedis_\" + frovedis.__version__\n",
    "f_report = evaluate(f_est, E_NM,\n",
    "                    x_train, y_train, x_test, y_test)\n",
    "f_est.release()\n",
    "FrovedisServer.shut_down()\n",
    "\n",
    "import sklearn\n",
    "from sklearn.naive_bayes import BernoulliNB as skNB\n",
    "s_est = skNB(alpha=1.0)\n",
    "E_NM = TARGET + \"_sklearn_\" + sklearn.__version__\n",
    "s_report = evaluate(s_est, E_NM,\n",
    "                    x_train, y_train, x_test, y_test)\n",
    "# Precision, recall and F1 score for each class\n",
    "print(\"Frovedis Bernoulli Naive Bayes metrics: \")\n",
    "print(f_report)\n",
    "print(\"Sklearn Bernoulli Naive Bayes metrics: \")\n",
    "print(s_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2 Decision Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/adityaw/virt1/lib64/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/adityaw/virt1/lib64/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/adityaw/virt1/lib64/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frovedis Decision Tree metrics: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "  rating 1.0       0.37      0.27      0.31     59017\n",
      "  rating 2.0       0.22      0.01      0.01     41354\n",
      "  rating 3.0       0.23      0.04      0.07     61591\n",
      "  rating 4.0       0.26      0.03      0.06    144753\n",
      "  rating 5.0       0.64      0.96      0.77    461128\n",
      "\n",
      "    accuracy                           0.61    767843\n",
      "   macro avg       0.34      0.26      0.24    767843\n",
      "weighted avg       0.49      0.61      0.50    767843\n",
      "\n",
      "Sklearn Decision Tree metrics: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "  rating 1.0       0.00      0.00      0.00     59017\n",
      "  rating 2.0       0.00      0.00      0.00     41354\n",
      "  rating 3.0       0.00      0.00      0.00     61591\n",
      "  rating 4.0       0.00      0.00      0.00    144753\n",
      "  rating 5.0       0.60      1.00      0.75    461128\n",
      "\n",
      "    accuracy                           0.60    767843\n",
      "   macro avg       0.12      0.20      0.15    767843\n",
      "weighted avg       0.36      0.60      0.45    767843\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/adityaw/virt1/lib64/python3.6/site-packages/sklearn/metrics/_classification.py:1245: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "# Demo: DecisionTreeClassifier\n",
    "import frovedis\n",
    "TARGET = \"decision_tree\"\n",
    "from frovedis.exrpc.server import FrovedisServer\n",
    "FrovedisServer.initialize(\"mpirun -np 8 \" + os.environ[\"FROVEDIS_SERVER\"])\n",
    "from frovedis.mllib.tree import DecisionTreeClassifier as frovDecisionTreeClassifier\n",
    "\n",
    "f_est = frovDecisionTreeClassifier(max_leaf_nodes=5, max_depth=12)\n",
    "E_NM = TARGET + \"_frovedis_\" + frovedis.__version__\n",
    "f_report = evaluate(f_est, E_NM,\n",
    "                    x_train, y_train, x_test, y_test)\n",
    "f_est.release()\n",
    "FrovedisServer.shut_down()\n",
    "\n",
    "import sklearn\n",
    "from sklearn.tree import DecisionTreeClassifier as skDecisionTreeClassifier\n",
    "s_est = skDecisionTreeClassifier(max_leaf_nodes=5, max_depth=12)\n",
    "E_NM = TARGET + \"_sklearn_\" + sklearn.__version__\n",
    "s_report = evaluate(s_est, E_NM,\n",
    "                    x_train, y_train, x_test, y_test)\n",
    "\n",
    "# Precision, recall and F1 score for each class\n",
    "print(\"Frovedis Decision Tree metrics: \")\n",
    "print(f_report)\n",
    "print(\"Sklearn Decision Tree metrics: \")\n",
    "print(s_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='green'> 4. Performance summary</font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>estimator</th>\n",
       "      <th>train time</th>\n",
       "      <th>test time</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>f1-score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bernoulli_naive_bayes_frovedis_0.9.10</td>\n",
       "      <td>3.8471</td>\n",
       "      <td>2.5516</td>\n",
       "      <td>0.518564</td>\n",
       "      <td>0.354237</td>\n",
       "      <td>0.394694</td>\n",
       "      <td>0.366431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bernoulli_naive_bayes_sklearn_0.24.1</td>\n",
       "      <td>19.9974</td>\n",
       "      <td>6.7362</td>\n",
       "      <td>0.518564</td>\n",
       "      <td>0.354237</td>\n",
       "      <td>0.394694</td>\n",
       "      <td>0.366431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>decision_tree_frovedis_0.9.10</td>\n",
       "      <td>35.7016</td>\n",
       "      <td>3.6518</td>\n",
       "      <td>0.609019</td>\n",
       "      <td>0.344660</td>\n",
       "      <td>0.261983</td>\n",
       "      <td>0.243691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>decision_tree_sklearn_0.24.1</td>\n",
       "      <td>730.6046</td>\n",
       "      <td>1.5582</td>\n",
       "      <td>0.600550</td>\n",
       "      <td>0.120110</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.150086</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                               estimator  train time  test time  accuracy  \\\n",
       "0  bernoulli_naive_bayes_frovedis_0.9.10      3.8471     2.5516  0.518564   \n",
       "1   bernoulli_naive_bayes_sklearn_0.24.1     19.9974     6.7362  0.518564   \n",
       "2          decision_tree_frovedis_0.9.10     35.7016     3.6518  0.609019   \n",
       "3           decision_tree_sklearn_0.24.1    730.6046     1.5582  0.600550   \n",
       "\n",
       "   precision    recall  f1-score  \n",
       "0   0.354237  0.394694  0.366431  \n",
       "1   0.354237  0.394694  0.366431  \n",
       "2   0.344660  0.261983  0.243691  \n",
       "3   0.120110  0.200000  0.150086  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summary = pd.DataFrame(OrderedDict({ \"estimator\": estimator_name,\n",
    "                                     \"train time\": train_time,\n",
    "                                     \"test time\": test_time,\n",
    "                                     \"accuracy\": accuracy,\n",
    "                                     \"precision\": precision,\n",
    "                                     \"recall\": recall,\n",
    "                                     \"f1-score\": f1\n",
    "                                  }))\n",
    "summary"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
